/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

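/*!
 * \file test_aclnn_matmul_reduce_scatter.cpp
 * \brief Multi-device sample for aclnnMatmulReduceScatter: one worker thread per device (DEV_NUM ranks) runs the
 *        two-phase aclnn call on its own stream, using that rank's communicator from HcclCommInitAll.
 */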

 #include <unistd.h>

 #include <cstdint>
 #include <cstdio>
 #include <iostream>
 #include <thread>
 #include <vector>

 #include "../op_host/op_api/aclnn_matmul_reduce_scatter.h"
 
 // Evaluates `cond`; when it is false, executes `return_expr` (in this file: an optional LOG_PRINT followed by a
 // return statement). The do { } while (0) wrapper makes the macro usable as a single statement.
 #define CHECK_RET(cond, return_expr) \
     do {                             \
         if (!(cond)) {               \
             return_expr;             \
         }                            \
     } while (0)
 
 // printf-style logging helper. `##__VA_ARGS__` (a GCC/Clang extension) drops the trailing comma when the
 // message has no format arguments.
 #define LOG_PRINT(message, ...)         \
     do {                                \
         printf(message, ##__VA_ARGS__); \
     } while(0)
 
 // Number of devices/ranks in this sample: main() sets up devices 0..DEV_NUM-1 and starts one worker thread per
 // device.
 constexpr int DEV_NUM = 8;
 
 /**
  * Returns the element count described by `shape`, i.e. the product of all its dimensions.
  * An empty shape yields 1 (a scalar); any zero dimension yields 0.
  */
 int64_t GetShapeSize(const std::vector<int64_t> &shape)
 {
     int64_t elementCount = 1;
     for (size_t dim = 0; dim < shape.size(); ++dim) {
         elementCount *= shape[dim];
     }
     return elementCount;
 }
 
 template<typename T>
 int CreateAclTensor(const std::vector<T> &hostData, const std::vector<int64_t> &shape, void **deviceAddr,
     aclDataType dataType, aclTensor **tensor)
 {
     auto size = GetShapeSize(shape) * sizeof(T);
     auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
     CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc failed. ret: %d\n", ret); return ret);
     ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
     CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMemcpy failed. ret: %d\n", ret); return ret);
     std::vector<int64_t> strides(shape.size(), 1);
     for (int64_t i = shape.size() - 2; i >= 0; i--) {
         strides[i] = shape[i +1] * strides[i + 1];
     }
     *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
         shape.data(), shape.size(), *deviceAddr);
     return 0;
 }
 
 // Per-rank parameters handed to launchOneThread_MmReduceScatter. Members have defaults so an Args that is
 // never filled in does not carry indeterminate values.
 struct Args {
     int rankId = -1;               // device id for aclrtSetDevice/aclrtResetDevice; -1 = not yet assigned
     HcclComm hcclComm = nullptr;   // this rank's communicator (main() assigns it from HcclCommInitAll's output)
     aclrtStream stream = nullptr;  // stream created in main(); the worker thread destroys it
 };
 
 int launchOneThread_MmReduceScatter(Args &args)
 {
     int ret = aclrtSetDevice(args.rankId);
     CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetDevice failed. ret = %d \n", ret); return ret);
 
     char hcomName[128] = {0};
     ret = HcclGetCommName(args.hcclComm, hcomName);
     CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclGetCommName failed. ERROR: %d\n", ret); return -1);
     LOG_PRINT("[INFO] rank = %d, hcomName = %s, stream = %p\n", args.rankId, hcomName, args.stream);
     std::vector<int64_t> x1Shape = {1024, 256};
     std::vector<int64_t> x2Shape = {256, 512};
     std::vector<int64_t> biasShape = {512};
     std::vector<int64_t> outShape = {1024 / DEV_NUM, 512};
     void *x1DeviceAddr = nullptr;
     void *x2DeviceAddr = nullptr;
     void *biasDeviceAddr = nullptr;
     void *outDeviceAddr = nullptr;
     aclTensor *x1 = nullptr;
     aclTensor *x2 = nullptr;
     aclTensor *bias = nullptr;
     aclTensor *out = nullptr;
 
     int64_t commTurn = 0;
     int64_t streamMode = 1;
     uint64_t workspaceSize = 0;
     aclOpExecutor *executor = nullptr;
     void *workspaceAddr = nullptr;
 
     long long x1ShapeSize = GetShapeSize(x1Shape);
     long long x2ShapeSize = GetShapeSize(x2Shape);
     long long biasShapeSize = GetShapeSize(biasShape);
     long long outShapeSize = GetShapeSize(outShape);
 
     std::vector<int16_t> x1HostData(x1ShapeSize, 0);
     std::vector<int16_t> x2HostData(x2ShapeSize, 0);
     std::vector<int16_t> biasHostData(biasShapeSize, 0);
     std::vector<int16_t> outHostData(outShapeSize, 0);
     // 创建tensor
     ret = CreateAclTensor(x1HostData, x1Shape, &x1DeviceAddr, aclDataType::ACL_FLOAT16, &x1);
     CHECK_RET(ret == ACL_SUCCESS, return ret);
     ret = CreateAclTensor(x2HostData, x2Shape, &x2DeviceAddr, aclDataType::ACL_FLOAT16, &x2);
     CHECK_RET(ret == ACL_SUCCESS, return ret);
     ret = CreateAclTensor(outHostData, outShape, &outDeviceAddr, aclDataType::ACL_FLOAT16, &out);
     CHECK_RET(ret == ACL_SUCCESS, return ret);
 
     // 调用第一阶段接口
     ret = aclnnMatmulReduceScatterGetWorkspaceSize(
         x1, x2, bias, hcomName, "sum", commTurn, streamMode, out, &workspaceSize, &executor);
     CHECK_RET(ret == ACL_SUCCESS,
         LOG_PRINT("[ERROR] aclnnMatmulReduceScatterGetWorkspaceSize failed. ret = %d \n", ret); return ret);
     // 根据第一阶段接口计算出的workspaceSize申请device内存
     if (workspaceSize > 0) {
         ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
         CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc workspace failed. ret = %d \n", ret); return ret);
     }
     // 调用第二阶段接口
     ret = aclnnMatmulReduceScatter(workspaceAddr, workspaceSize, executor, args.stream);
     CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMatmulReduceScatter failed. ret = %d \n", ret); return ret);
     // （固定写法）同步等待任务执行结束
     ret = aclrtSynchronizeStreamWithTimeout(args.stream, 10000);
     CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSynchronizeStreamWithTimeout failed. ret = %d \n", ret);
         return ret);
     LOG_PRINT("[INFO] device_%d aclnnMatmulReduceScatter execute successfully.\n", args.rankId);
     // 释放device资源，需要根据具体API的接口定义修改
     if (x1 != nullptr) {
         aclDestroyTensor(x1);
     }
     if (x2 != nullptr) {
         aclDestroyTensor(x2);
     }
     if (bias != nullptr) {
         aclDestroyTensor(bias);
     }
     if (out != nullptr) {
         aclDestroyTensor(out);
     }
     if (x1DeviceAddr != nullptr) {
         aclrtFree(x1DeviceAddr);
     }
     if (x2DeviceAddr != nullptr) {
         aclrtFree(x2DeviceAddr);
     }
     if (biasDeviceAddr != nullptr) {
         aclrtFree(biasDeviceAddr);
     }
     if (outDeviceAddr != nullptr) {
         aclrtFree(outDeviceAddr);
     }
     if (workspaceSize > 0) {
         aclrtFree(workspaceAddr);
     }
     ret = aclrtDestroyStream(args.stream);
     CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtDestroyStream failed. ret = %d \n", ret); return ret);
     ret = aclrtResetDevice(args.rankId);
     CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtResetDevice failed. ret = %d \n", ret); return ret);
     return 0;
 }
 
 int main(int argc, char *argv[])
 {
     int ret = aclInit(nullptr);
     CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclInit failed. ret = %d \n", ret); return ret);
     aclrtStream stream[DEV_NUM];
     for (uint32_t rankId = 0; rankId < DEV_NUM; rankId++) {
         ret = aclrtSetDevice(rankId);
         CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetDevice failed. ret = %d \n", ret); return ret);
         ret = aclrtCreateStream(&stream[rankId]);
         CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed. ret = %d \n", ret); return ret);
     }
     int32_t devices[DEV_NUM];
     for (int i = 0; i < DEV_NUM; i++) {
         devices[i] = i;
     }
     // 初始化集合通信域
     HcclComm comms[DEV_NUM];
     ret = HcclCommInitAll(DEV_NUM, devices, comms);
     CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclCommInitAll failed. ret = %d \n", ret); return ret);
 
     Args args[DEV_NUM];
     // 启动多线程
     std::vector<std::unique_ptr<std::thread>> threads(DEV_NUM);
     for (uint32_t rankId = 0; rankId < DEV_NUM; rankId++) {
         args[rankId].rankId = rankId;
         args[rankId].hcclComm = comms[rankId];
         args[rankId].stream = stream[rankId];
         threads[rankId].reset(new(std::nothrow) std::thread(&launchOneThread_MmReduceScatter, std::ref(args [rankId])));
     }
     for (uint32_t rankId = 0; rankId < DEV_NUM; rankId++) {
         threads[rankId]->join();
     }
     for (int i = 0; i < DEV_NUM; i++) {
         auto hcclRet = HcclCommDestroy(comms[i]);
         CHECK_RET(hcclRet == HCCL_SUCCESS, LOG_PRINT("[ERROR] HcclCommDestroy failed. ret = %d \n", ret); return -1);
     }
     aclFinalize();
     _exit(0);
 }